#%%
import os
import ast
import re
import shutil
import sys
import traceback
from typing import List, Dict, Set, Optional, Tuple, Any

def parse_args():
    """Resolve the source directory and the output directory from ``sys.argv``.

    Positional usage (all optional)::

        script.py [SOURCE_DIR [TARGET_BASE_DIR]]

    * two or more args: SOURCE_DIR and TARGET_BASE_DIR (further args ignored)
    * one arg: SOURCE_DIR; TARGET_BASE_DIR is the parent directory of SOURCE_DIR
    * no args: SOURCE_DIR = ``./strategies``, TARGET_BASE_DIR = ``.``

    Arguments starting with ``--f=`` are discarded first (presumably injected by
    an interactive/Jupyter kernel, since this file is laid out as ``#%%`` cells --
    confirm).

    Returns:
        ``(source_dir, target_dir)`` as absolute paths, where ``target_dir`` is
        ``TARGET_BASE_DIR/refactored``. Both paths are also printed.
    """
    args = [arg for arg in sys.argv[1:] if not arg.startswith('--f=')]

    if len(args) >= 2:
        source_dir = args[0]
        target_dir_base = args[1]
    elif len(args) == 1:
        source_dir = args[0]
        # Normalize before taking the parent: os.path.dirname("strategies/")
        # returns "strategies" itself, which would place the output *inside*
        # the source directory instead of next to it.
        target_dir_base = os.path.dirname(os.path.abspath(source_dir))
    else:
        source_dir = "./strategies"  # Default source directory (assuming run from parent)
        target_dir_base = "."  # Default target base dir (current dir)

    source_dir = os.path.abspath(source_dir)
    target_dir = os.path.abspath(os.path.join(target_dir_base, "refactored"))

    print(f"Source directory: {source_dir}")
    print(f"Target directory: {target_dir}")

    return source_dir, target_dir

# Module-level registries. scan_modules() populates the first four;
# generate_class_files() reads class_info, module_classes and
# module_top_level_code, and records its output in processed_classes.

# class name -> {"module": module name, "bases": [simple base names], "definition": source text}
class_info: Dict[str, Dict[str, Any]] = {}
# module name -> names of the classes defined at the module's top level, in source order
module_classes: Dict[str, List[str]] = {}
# module name -> source text of each top-level import statement
module_imports: Dict[str, List[str]] = {}
# module name -> source text of every top-level statement, in source order
module_top_level_code: Dict[str, List[str]] = {}
# names of classes already written to an output file (directly or as an included base)
processed_classes: Set[str] = set()


def _node_source(content: str, node: ast.stmt, filename: str) -> str:
    """Return the source text of a top-level statement, decorators included.

    Since Python 3.8 a decorated ``def``/``class`` node's ``lineno`` points at
    the ``def``/``class`` keyword, so ``ast.get_source_segment(content, node)``
    alone silently drops the ``@decorator`` lines. For top-level statements the
    ``@`` is in column 0, so the span is widened to start there. Falls back to
    ``ast.unparse`` (which does include decorators) when no segment is available.
    """
    span = node
    decorators = getattr(node, "decorator_list", None)
    if decorators:
        first_decorator_line = min(d.lineno for d in decorators)
        # Position-only stand-in covering the decorators plus the definition.
        span = ast.Pass(
            lineno=first_decorator_line,
            col_offset=0,
            end_lineno=node.end_lineno,
            end_col_offset=node.end_col_offset,
        )
    try:
        node_code = ast.get_source_segment(content, span)
    except Exception as segment_error:
        print(f"  Warning: Error getting source segment ({segment_error}), using ast.unparse().")
        return ast.unparse(node)
    if node_code is None:
        print(f"  Warning: Could not get source segment for a node in {filename}, using ast.unparse().")
        return ast.unparse(node)
    return node_code


def _base_names(node: ast.ClassDef) -> List[str]:
    """Return the simple (last dotted component) name of each base of a class."""
    bases = []
    for base_node in node.bases:
        try:
            bases.append(ast.unparse(base_node).split('.')[-1])
        except Exception:
            print(f"  Warning: Could not determine base class name for '{ast.dump(base_node)}' in class {node.name}")
    return bases


def scan_modules(source_dir: str):
    """Parse each ``*.py`` file in *source_dir* and record its top-level structure.

    Files whose names start with ``__`` are skipped. For every module this fills
    the module-level registries:

    * ``module_top_level_code[module]`` -- source of every top-level statement, in order
    * ``module_imports[module]`` -- source of each top-level import statement
    * ``module_classes[module]`` -- names of the top-level classes, in order
    * ``class_info[class]`` -- defining module, simple base names, source text

    A file that cannot be read or parsed is reported and skipped; scanning
    continues with the next file. If an unexpected error interrupts a module
    part-way, ``class_info`` may already hold some of its classes while the
    per-module registries do not.
    """
    global class_info, module_classes, module_imports, module_top_level_code

    # Sorted so the scan order (and which definition wins on a duplicate class
    # name) is reproducible; os.listdir returns entries in arbitrary order.
    for filename in sorted(os.listdir(source_dir)):
        if not filename.endswith('.py') or filename.startswith('__'):
            continue

        module_name = filename[:-3]
        file_path = os.path.join(source_dir, filename)
        print(f"\nScanning Module: {module_name}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"  ERROR: Could not read file {filename}: {e}")
            continue

        imports = []
        classes_in_module = []
        current_module_top_level_code = []

        try:
            tree = ast.parse(content, filename=filename)

            for node in tree.body:
                node_code = _node_source(content, node, filename)
                current_module_top_level_code.append(node_code)

                if isinstance(node, (ast.Import, ast.ImportFrom)):
                    imports.append(node_code)
                elif isinstance(node, ast.ClassDef):
                    class_name = node.name
                    classes_in_module.append(class_name)

                    previous = class_info.get(class_name)
                    if previous is not None and previous["module"] != module_name:
                        print(f"  Warning: class '{class_name}' is also defined in module "
                              f"'{previous['module']}'; the definition from '{module_name}' replaces it.")

                    class_info[class_name] = {
                        "module": module_name,
                        "bases": _base_names(node),
                        "definition": node_code,  # exact source text, decorators included
                    }

            module_imports[module_name] = imports
            module_top_level_code[module_name] = current_module_top_level_code
            module_classes[module_name] = classes_in_module

            print(f"  AST Scan: Found {len(classes_in_module)} classes, {len(current_module_top_level_code)} top-level code blocks.")

        except SyntaxError as e:
            print(f"  ERROR: Skipping module {module_name} due to SyntaxError: {e}")
        except Exception:
            print(f"  ERROR: Skipping module {module_name} due to unexpected error during AST processing:")
            traceback.print_exc()  # Full traceback to stderr


def get_all_base_classes(class_name: str, visited=None) -> Set[str]:
    """Return every transitive ancestor of *class_name* that is itself a scanned class.

    Base names absent from ``class_info`` (``object``, third-party classes, ...)
    are ignored and not followed. *visited* is shared, mutated bookkeeping: a
    class already in it is never expanded, and if *class_name* is already in it
    the result is an empty set. In a cyclic hierarchy the class itself can end
    up in the result.
    """
    seen = set() if visited is None else visited
    if class_name in seen:
        return set()
    seen.add(class_name)

    ancestors: Set[str] = set()
    worklist = [class_name]
    while worklist:
        current = worklist.pop()
        entry = class_info.get(current)
        if entry is None:
            # Only the starting class can be unknown; it has no known ancestors.
            continue
        for parent in entry["bases"]:
            if parent not in class_info:
                continue
            ancestors.add(parent)
            if parent not in seen:
                seen.add(parent)
                worklist.append(parent)
    return ancestors

def _is_class_block(code_block: str) -> bool:
    """Return True if a top-level code block is a (possibly decorated) class definition.

    Parsing the block avoids the false positives of a purely textual test, such
    as a decorated *function* whose body merely contains the text "class ".
    """
    try:
        parsed = ast.parse(code_block)
    except (SyntaxError, ValueError):
        # Not parseable on its own: fall back to the textual heuristic.
        stripped = code_block.strip()
        return stripped.startswith("class ") or (stripped.startswith("@") and "class " in code_block)
    return any(isinstance(stmt, ast.ClassDef) for stmt in parsed.body)


def generate_class_files(target_dir: str):
    """Write one ``<ClassName>.py`` file per scanned class into *target_dir*.

    Each file contains, in order:

    1. every non-class top-level statement of the class's own module
       (imports, functions, assignments, ...);
    2. the definitions of its in-codebase ancestors, shallowest first;
    3. the class's own definition.

    Classes are processed shallowest-first (by number of known ancestors). Every
    class written, directly or as an included ancestor, is added to the global
    ``processed_classes`` and is not generated again. Reads the registries filled
    by ``scan_modules``; classes whose module data is missing are skipped.

    NOTE(review): only the class's *own* module preamble is copied. Names needed
    by an ancestor defined in another module (e.g. that module's imports) are not
    included -- confirm generated files with cross-module bases import cleanly.
    """
    global processed_classes

    os.makedirs(target_dir, exist_ok=True)

    def hierarchy_key(name: str) -> Tuple[int, str]:
        # Depth first so ancestors precede descendants; the name breaks ties so
        # the order never depends on set iteration order (str hashing is
        # randomized per process).
        return len(get_all_base_classes(name)), name

    # Stable sort on depth only: equally deep classes keep scan order.
    sorted_class_names = sorted(class_info, key=lambda name: len(get_all_base_classes(name)))
    # Non-class top-level code per module, computed once per module.
    preamble_cache: Dict[str, List[str]] = {}

    for class_name in sorted_class_names:
        if class_name in processed_classes:
            continue

        info = class_info[class_name]
        module_name = info["module"]
        target_file = os.path.join(target_dir, f"{class_name}.py")

        # Module data can be missing if scanning that module failed part-way.
        if module_name not in module_top_level_code:
            print(f"  Skipping generation for {class_name}: Module data for '{module_name}' not found (scan likely failed).")
            continue

        if module_name not in preamble_cache:
            preamble_cache[module_name] = [
                block for block in module_top_level_code[module_name] if not _is_class_block(block)
            ]

        content_lines: List[str] = []
        included_defs: Set[str] = set()
        added_bases: List[str] = []

        # 1. All non-class top-level code from the original module, in order.
        for code_block in preamble_cache[module_name]:
            content_lines.extend((code_block, ""))

        # 2. Definitions of all required in-codebase ancestors.
        for base_name in sorted(get_all_base_classes(class_name), key=hierarchy_key):
            if base_name not in class_info or base_name in included_defs:
                continue
            base_module_name = class_info[base_name].get("module")
            if base_module_name and base_module_name not in module_top_level_code:
                print(f"  Warning: Skipping base class '{base_name}' for '{class_name}' because its module '{base_module_name}' data is missing.")
                continue
            content_lines.extend((class_info[base_name]["definition"], ""))
            included_defs.add(base_name)
            added_bases.append(base_name)

        # 3. The class itself (already present only in a cyclic hierarchy).
        if class_name not in included_defs:
            content_lines.append(info["definition"])
            included_defs.add(class_name)

        try:
            with open(target_file, 'w', encoding='utf-8') as f:
                # Terminate the file with exactly one newline.
                f.write('\n'.join(content_lines).rstrip('\n') + '\n')
        except (OSError, UnicodeError) as e:
            print(f"  ERROR: Failed to write file {target_file}: {e}")
            continue

        print(f"Created {class_name}.py from {module_name}.py" +
              (f" (including {len(added_bases)} explicitly added parent classes)" if added_bases else ""))

        processed_classes.add(class_name)
        processed_classes.update(included_defs)  # Included ancestors count as processed


def extract_strategy_names_from_list(source_dir: str) -> List[str]:
    """Collect strategy class names referenced in ``*.txt`` files in *source_dir*.

    Matches dotted references of the form ``axelrod.strategies.<module>.<Name>``
    and returns ``<Name>`` for each, de-duplicated in first-seen order. Files are
    read as UTF-8 in sorted filename order; a file that cannot be read is
    reported and skipped.
    """
    # The module segment is restricted to identifier characters. The previous
    # ``[^.]+`` could span spaces/newlines (e.g. "axelrod.strategies.\nfoo.Bar")
    # and yield names that are not strategies.
    pattern = re.compile(r'axelrod\.strategies\.\w+\.([A-Za-z0-9_]+)')
    strategy_names: List[str] = []

    for filename in sorted(os.listdir(source_dir)):
        if not filename.endswith('.txt'):
            continue
        file_path = os.path.join(source_dir, filename)
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"  WARNING: Could not read strategy list file {filename}: {e}")
            continue
        strategy_names.extend(pattern.findall(content))

    # Drop repeats so counts and missing-strategy reports are not inflated.
    return list(dict.fromkeys(strategy_names))

def verify_strategies(target_dir: str, expected_strategies: List[str]) -> List[str]:
    """Report which expected strategies have no ``<Name>.py`` file in *target_dir*.

    Only reports; no files are created. Each strategy name is checked once even
    if it appears several times in *expected_strategies*.

    Returns:
        The missing strategy names, in first-seen order (empty if none).
    """
    missing_strategies = [
        strategy
        for strategy in dict.fromkeys(expected_strategies)
        if not os.path.exists(os.path.join(target_dir, f"{strategy}.py"))
    ]

    if missing_strategies:
        print(f"\nWARNING: Found {len(missing_strategies)} strategies missing from the outputs:")
        for strategy in missing_strategies:
            print(f"  - {strategy}")

        print("\nPlease check if these strategies exist in the source files.")
    else:
        print("\nAll expected strategies have been processed successfully!")

    return missing_strategies


if __name__ == "__main__":
    # Where to read modules from, and where the per-class files go.
    source_dir, target_dir = parse_args()

    # Optional: strategy names listed in *.txt files, checked at the very end.
    expected_strategies = extract_strategy_names_from_list(source_dir)
    if expected_strategies:
        print(f"Found {len(expected_strategies)} expected strategies in text list.")

    # Phase 1: AST scan of every module; phase 2: one output file per class.
    scan_modules(source_dir)
    generate_class_files(target_dir)

    # Summary. The .py file count also includes files that were already in
    # target_dir, so it only approximates what this run produced.
    total_classes = len(class_info)
    try:
        created_files_count = sum(1 for entry in os.listdir(target_dir) if entry.endswith('.py'))
    except FileNotFoundError:
        created_files_count = 0

    print("\nRefactoring complete!")
    print(f"Scanned modules containing approximately {total_classes} classes.")
    print(f"Created {created_files_count} class files in: {os.path.abspath(target_dir)}")

    # Cross-check generated files against the text list, when one was found.
    if expected_strategies:
        verify_strategies(target_dir, expected_strategies)
#%%
